package mongox

import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"time"

	"gitee.com/zhongguo168a/go-nodex/dbx"
	"gitee.com/zhongguo168a/gocodes/datax"
	"gitee.com/zhongguo168a/gocodes/datax/convertx"
	"gitee.com/zhongguo168a/gocodes/myx/errorx"
	"gitee.com/zhongguo168a/gocodes/myx/logx"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readpref"
	"stathat.com/c/consistent"
)

// DBSubMode is the sub-division (sharding) mode of a Database; see
// Database.SubMode.
// NOTE(review): SubMode is not read anywhere in this file — confirm how callers
// interpret it.
type DBSubMode int

const (
	// DBSubMode_None presumably means the database is not split across
	// servers — confirm against callers.
	DBSubMode_None DBSubMode = iota
	// DBSubMode_Hash splits the data horizontally across different servers.
	DBSubMode_Hash
)

// Addr is a single MongoDB server endpoint together with the client that
// GetConnection lazily creates for it.
type Addr struct {
	IP   string
	Port string
	// User and Pwd are optional credentials; GetUrl omits them when User is empty.
	User string
	Pwd  string

	// Connection pool: the cached *mongo.Client, set by GetConnection and
	// disconnected by Close.
	client *mongo.Client
}

// GetUrl builds the mongodb:// connection URI for addr.
//
// Credentials are included only when User is non-empty. They are
// percent-encoded, as the MongoDB connection-string format requires, so a user
// name or password containing characters such as '@', ':', '/' or '?' no longer
// corrupts the URI.
//
// NOTE(review): a Pwd that is already percent-encoded is now encoded a second
// time — confirm no configuration relies on pre-encoding.
func (addr *Addr) GetUrl() string {
	hostPort := fmt.Sprintf("%v:%v", addr.IP, addr.Port)
	if addr.User == "" {
		return "mongodb://" + hostPort
	}
	userInfo := url.UserPassword(addr.User, addr.Pwd).String()
	return "mongodb://" + userInfo + "@" + hostPort
}

// Close disconnects the client cached on addr, if any, and forgets it so that a
// later GetConnection creates a fresh client instead of returning a
// disconnected one. Closing an addr that never connected is a no-op.
func (addr *Addr) Close(ctx context.Context) (err error) {
	client := addr.client
	if client == nil {
		return nil
	}
	// Clear first: a client cannot be reused after Disconnect, whether or not
	// Disconnect reports an error.
	addr.client = nil
	return client.Disconnect(ctx)
}

// Ping checks that the server at addr is reachable, creating and caching the
// client first if needed (see GetConnection).
func (addr *Addr) Ping() (err error) {
	conn, err := addr.GetConnection()
	if err != nil {
		return errorx.Wrap(err, "get connection")
	}
	// Never pass a nil Context. A nil read preference makes the driver use
	// the client's configured default.
	return conn.Ping(context.Background(), nil)
}

// GetConnection returns the client (connection pool) for addr, creating and
// caching it on first use.
//
// A new client is configured with a pool of 10 to 100 connections. Connecting
// and the reachability ping against the primary share one 3-second timeout. If
// either step fails, no client is cached, so the next call retries. A client
// that connected but failed the ping is disconnected rather than leaked.
//
// NOTE(review): the lazy initialisation is not synchronised; concurrent first
// calls may each create a client — confirm callers serialise first use.
func (addr *Addr) GetConnection() (client *mongo.Client, err error) {
	if addr.client != nil {
		return addr.client, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	opt := options.Client().ApplyURI(addr.GetUrl())
	opt.SetMaxPoolSize(100)
	opt.SetMinPoolSize(10)
	client, err = mongo.Connect(ctx, opt)
	if err != nil {
		return nil, errorx.Wrap(err, "dial mongo")
	}

	// Verify the server is actually available before caching the client, and
	// bound the check by the same timeout as the connect.
	if err = client.Ping(ctx, readpref.Primary()); err != nil {
		// Release the pool's resources; this client is being discarded.
		_ = client.Disconnect(context.Background())
		return nil, errorx.Wrap(err, "ping mongo")
	}

	addr.client = client
	return client, nil
}

// Database describes a logical MongoDB database whose data may be spread over
// several servers (Addrs).
type Database struct {
	// Name is the logical database name, without Prefix.
	Name string
	// Addrs are the servers backing this database; GetAddr picks one per subkey.
	Addrs []*Addr
	// SubMode is the sub-division (sharding) mode of this database.
	SubMode DBSubMode
	// consistent is the consistent-hash ring GetAddrIndex/GetAddr use to map a
	// subkey to an index into Addrs when there is more than one address.
	// NOTE(review): its members are presumably the decimal indexes of Addrs
	// (lookups are converted with convertx.AnyToInt) — confirm where it is built.
	consistent *consistent.Consistent

	// Prefix is prepended to the database name (see getLastName).
	Prefix string
}

// GetName reports the logical database name; Prefix is not included.
func (c *Database) GetName() string {
	name := c.Name
	return name
}

// getLastName yields the prefixed database name: Prefix immediately followed
// by Name.
func (c *Database) getLastName() string {
	prefix, name := c.Prefix, c.Name
	return prefix + name
}

// GetAddrIndex returns the index into Addrs responsible for subkey.
//
// With exactly one address the answer is always 0 and the hash ring is not
// consulted. Otherwise subkey is looked up on the consistent-hash ring and the
// member found is converted to an int; a lookup failure is returned wrapped
// with the subkey.
func (c *Database) GetAddrIndex(subkey string) (int, error) {
	if len(c.Addrs) == 1 {
		return 0, nil
	}

	member, err := c.consistent.Get(subkey)
	if err != nil {
		return 0, errorx.Wrap(err, "hash", datax.M{"subkey": subkey})
	}
	return convertx.AnyToInt(member), nil
}

// GetAddr returns the server in Addrs responsible for subkey, chosen by
// GetAddrIndex (always Addrs[0] when there is a single address).
//
// It returns an error instead of panicking when Addrs is empty or when the
// hash ring yields an index outside Addrs (e.g. a ring built from a different
// address list). Hash lookup errors from GetAddrIndex are returned unchanged.
func (c *Database) GetAddr(subkey string) (*Addr, error) {
	if len(c.Addrs) == 0 {
		return nil, fmt.Errorf("%v: no addresses configured", c)
	}

	index, err := c.GetAddrIndex(subkey)
	if err != nil {
		return nil, err
	}
	if index < 0 || index >= len(c.Addrs) {
		return nil, fmt.Errorf("%v: addr index %d for subkey %q out of range [0, %d)",
			c, index, subkey, len(c.Addrs))
	}
	return c.Addrs[index], nil
}

// String renders the database for logs as MONGO[<Name>].
func (c *Database) String() string {
	return fmt.Sprint("MONGO[", c.Name, "]")
}

// SaveTables persists items with one unordered BulkWrite against the table of
// items[0].
//
// Create and update items become upserting UpdateOne models and are skipped
// when their document is empty. Delete items become DeleteOne models.
//
// Return values:
//   - (nil, nil) when there is nothing to write or the write succeeds;
//   - (items, dbx.ErrDatabaseException) on a network error or timeout;
//   - (failed, err) on a mongo.BulkWriteException, where failed holds the items
//     whose individual writes were rejected. It holds all items when the
//     failures cannot be attributed (no per-write errors, or a write-concern
//     error);
//   - (items, err) for any other BulkWrite error.
//
// NOTE(review): every item is written to items[0].Table; callers presumably
// batch items per table — confirm.
func (c *Database) SaveTables(items []*dbx.SaveItem) (fails []*dbx.SaveItem, err error) {
	if len(items) == 0 {
		return nil, nil
	}
	req := NewRequestByName(items[0].Table)
	models, owners := saveItemsToWriteModels(items)
	if len(models) == 0 {
		return nil, nil
	}

	// Unordered: the server keeps applying the remaining writes after one
	// fails, so a single bad document does not block the rest of the batch.
	ordered := false
	_, bulkErr := req.BulkWrite(models, &options.BulkWriteOptions{
		Ordered: &ordered,
	})
	if bulkErr == nil {
		return nil, nil
	}

	if mongo.IsNetworkError(bulkErr) || mongo.IsTimeout(bulkErr) {
		return items, dbx.ErrDatabaseException
	}

	// Per-write failures: report the offending items instead of panicking.
	var bulkExc mongo.BulkWriteException
	if errors.As(bulkErr, &bulkExc) {
		fails = saveItemsFromBulkException(bulkExc, owners, items)
		logx.Error(errorx.Wrap(bulkErr, "SaveTables: BulkWrite", datax.M{"failed": len(fails)}))
		return fails, bulkErr
	}

	logx.Error(errorx.Wrap(bulkErr, "SaveTables: BulkWrite"))
	return items, bulkErr
}

// saveItemsToWriteModels converts items into bulk-write models. owners[i] is the
// item that produced models[i]. Items that produce no model (an empty
// create/update document, or an Opt other than create/update/delete) are
// skipped, so model indexes do not generally equal item indexes.
func saveItemsToWriteModels(items []*dbx.SaveItem) (models []mongo.WriteModel, owners []*dbx.SaveItem) {
	for _, item := range items {
		var model mongo.WriteModel
		switch item.Opt {
		case dbx.Opt_创建:
			if doc := item.GetDocumentByCreate(); len(doc) > 0 {
				model = mongo.NewUpdateOneModel().SetUpsert(true).SetFilter(item.GetQuery()).SetUpdate(doc)
			}
		case dbx.Opt_修改:
			if doc := item.GetDocumentByUpdate(); len(doc) > 0 {
				model = mongo.NewUpdateOneModel().SetUpsert(true).SetFilter(item.GetQuery()).SetUpdate(doc)
			}
		case dbx.Opt_删除:
			model = mongo.NewDeleteOneModel().SetFilter(item.GetQuery())
		}
		if model != nil {
			models = append(models, model)
			owners = append(owners, item)
		}
	}
	return models, owners
}

// saveItemsFromBulkException maps the per-write errors in exc back to the items
// that produced the failing models (owners[i] produced model i). When the
// failures cannot be attributed to specific writes, every item in all is
// reported as failed. This covers a write-concern error, no per-write errors,
// and no index that maps to an owner.
func saveItemsFromBulkException(exc mongo.BulkWriteException, owners, all []*dbx.SaveItem) []*dbx.SaveItem {
	if exc.WriteConcernError != nil || len(exc.WriteErrors) == 0 {
		return all
	}
	failed := make([]*dbx.SaveItem, 0, len(exc.WriteErrors))
	for _, we := range exc.WriteErrors {
		// Index is the position of the model in the slice given to BulkWrite.
		if we.Index >= 0 && we.Index < len(owners) {
			failed = append(failed, owners[we.Index])
		}
	}
	if len(failed) == 0 {
		return all
	}
	return failed
}
